package com.opencee.cloud.msg.thread;

import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.toolkit.IdWorker;
import com.opencee.cloud.msg.api.constatns.MsgTaskStatus;
import com.opencee.cloud.msg.api.constatns.MsgTaskType;
import com.opencee.cloud.msg.api.entity.MsgTaskEntity;
import com.opencee.cloud.msg.api.vo.MsgChannelConfigVO;
import com.opencee.cloud.msg.api.vo.params.MessageTaskParams;
import com.opencee.cloud.msg.config.MessageTaskContext;
import com.opencee.cloud.msg.service.MsgApplicationChannelService;
import com.opencee.cloud.msg.service.MsgTaskService;
import jodd.util.concurrent.ThreadFactoryBuilder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.Assert;

import java.util.Collection;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;

/**
 * 消息推送主进程
 *
 * @author yadu
 */
@Slf4j
public class MessageTaskDispatcher {

    Integer availableProcessors = Runtime.getRuntime().availableProcessors();
    Integer numOfThreads = availableProcessors * 2;


    private Collection<MessageTaskHandler> handlers;

    private Map<String, ExecutorService> executorServiceMap = new ConcurrentHashMap<>();

    private MsgTaskService msgTaskService;

    private MsgApplicationChannelService msgApplicationChannelService;

    private ExecutorService websocketPool = new ThreadPoolExecutor(availableProcessors, numOfThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>(1024), new ThreadFactoryBuilder()
            .setNameFormat("websocket-pool-%d").get(), new ThreadPoolExecutor.AbortPolicy());
    private ExecutorService emailPool = new ThreadPoolExecutor(availableProcessors, numOfThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>(1024), new ThreadFactoryBuilder()
            .setNameFormat("email-pool-%d").get(), new ThreadPoolExecutor.AbortPolicy());
    private ExecutorService pushPool = new ThreadPoolExecutor(availableProcessors, numOfThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>(1024), new ThreadFactoryBuilder()
            .setNameFormat("push-pool-%d").get(), new ThreadPoolExecutor.AbortPolicy());
    private ExecutorService smsPool = new ThreadPoolExecutor(availableProcessors, numOfThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>(1024), new ThreadFactoryBuilder()
            .setNameFormat("sms-pool-%d").get(), new ThreadPoolExecutor.AbortPolicy());


    public MessageTaskDispatcher(Collection<MessageTaskHandler> handlers, MsgTaskService msgTaskService, MsgApplicationChannelService msgApplicationChannelService) {
        this.handlers = handlers;
        this.msgTaskService = msgTaskService;
        this.msgApplicationChannelService = msgApplicationChannelService;
        executorServiceMap.put(MsgTaskType.PLATFORM.name(), websocketPool);
        executorServiceMap.put(MsgTaskType.EMAIL.name(), emailPool);
        executorServiceMap.put(MsgTaskType.PUSH.name(), pushPool);
        executorServiceMap.put(MsgTaskType.SMS.name(), smsPool);
    }

    public void dispatch(MessageTaskParams taskParams) throws Exception {
        if (taskParams != null && handlers != null) {
            Assert.hasText(taskParams.getAppId(), "appId不能为空");
            MsgTaskType taskType = MsgTaskType.getByValue(taskParams.getType());
            Assert.notNull(taskType, "任务类型不能为空");
            List<MsgChannelConfigVO> channelConfigList = msgApplicationChannelService.getChannelList(taskParams.getAppId());
            Assert.notEmpty(channelConfigList, "未设置任何通道:" + taskParams.getAppId());
            MessageTaskContext context = new MessageTaskContext();
            // 添加通道列表变量
            context.setVariable(VAR_CHANNEL_LIST,channelConfigList);
            Object params = taskParams.getParams();
            handlers.forEach((handler) -> {
                if (handler.support(params) && handler.validate(context, taskParams)) {
                    taskParams.setTaskId(IdWorker.getId());
                    ExecutorService executorService = executorServiceMap.get(taskType.name());
                    if (executorService != null) {
                        MsgTaskEntity entity = new MsgTaskEntity();
                        entity.setId(taskParams.getTaskId());
                        entity.setAppId(taskParams.getAppId());
                        entity.setDelayedTime(taskParams.getDelayedTime());
                        entity.setDelayedQueueRk(taskType.getDelayedQueueRK());
                        entity.setCreateTime(new Date());
                        entity.setStartTime(entity.getCreateTime());
                        entity.setStatus(MsgTaskStatus.CREATED.getValue());
                        entity.setParams(JSONObject.toJSONString(params));
                        entity.setType(taskParams.getType());
                        msgTaskService.save(entity);
                        log.info("任务创建:{}", entity.getId());
                        //使用全局线程池
                        executorService.submit(new MessageTaskCallable(context, handler, msgTaskService, taskParams));
                    }
                }
            });
        }

    }

}
